{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook computes the performance of the network on our Sketchy benchmark: it extracts features for the test photos, queries them with the test sketches, and reports recall @K=1.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pylab import *\n",
    "%matplotlib inline\n",
    "import os\n",
    "import sys"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## caffe"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we need to import Caffe. You'll need to have Caffe installed, along with its Python interface (pycaffe)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# TODO: specify your caffe root folder here.\n",
    "# Raw string so the Windows backslash stays a literal character instead of\n",
    "# relying on '\\c' being an unrecognised escape sequence.\n",
    "caffe_root = r\"X:\\caffe_siggraph/caffe-windows-master\"\n",
    "# Put this build's Python bindings first on sys.path so `import caffe` picks them up.\n",
    "sys.path.insert(0, caffe_root+'/python')\n",
    "import caffe"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can load up the network. You can change the path to your own network here. Make sure to use the matching deploy prototxt files and change the target layer to your layer name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "#TODO: change to your own network and deploying file\n",
    "# Trained weights (.caffemodel); this one file is passed to both caffe.Net constructors.\n",
    "PRETRAINED_FILE = '../models/triplet_googlenet/triplet_googlenet_finegrain_final.caffemodel' \n",
    "# Deploy prototxt used to build sketch_net.\n",
    "sketch_model = '../models/triplet_googlenet/googlenet_sketchdeploy.prototxt'\n",
    "# Deploy prototxt used to build img_net.\n",
    "image_model = '../models/triplet_googlenet/googlenet_imagedeploy.prototxt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['data',\n",
       " 'conv1/7x7_s2_s',\n",
       " 'pool1/3x3_s2_s',\n",
       " 'pool1/norm1_s',\n",
       " 'conv2/3x3_reduce_s',\n",
       " 'conv2/3x3_s',\n",
       " 'conv2/norm2_s',\n",
       " 'pool2/3x3_s2_s',\n",
       " 'pool2/3x3_s2_s_pool2/3x3_s2_s_0_split_0',\n",
       " 'pool2/3x3_s2_s_pool2/3x3_s2_s_0_split_1',\n",
       " 'pool2/3x3_s2_s_pool2/3x3_s2_s_0_split_2',\n",
       " 'pool2/3x3_s2_s_pool2/3x3_s2_s_0_split_3',\n",
       " 'inception_3a/1x1_s',\n",
       " 'inception_3a/3x3_reduce_s',\n",
       " 'inception_3a/3x3_s',\n",
       " 'inception_3a/5x5_reduce_s',\n",
       " 'inception_3a/5x5_s',\n",
       " 'inception_3a/pool_s',\n",
       " 'inception_3a/pool_proj_s',\n",
       " 'inception_3a/output_s',\n",
       " 'inception_3a/output_s_inception_3a/output_s_0_split_0',\n",
       " 'inception_3a/output_s_inception_3a/output_s_0_split_1',\n",
       " 'inception_3a/output_s_inception_3a/output_s_0_split_2',\n",
       " 'inception_3a/output_s_inception_3a/output_s_0_split_3',\n",
       " 'inception_3b/1x1_s',\n",
       " 'inception_3b/3x3_reduce_s',\n",
       " 'inception_3b/3x3_s',\n",
       " 'inception_3b/5x5_reduce_s',\n",
       " 'inception_3b/5x5_s',\n",
       " 'inception_3b/pool_s',\n",
       " 'inception_3b/pool_proj_s',\n",
       " 'inception_3b/output_s',\n",
       " 'pool3/3x3_s2_s',\n",
       " 'pool3/3x3_s2_s_pool3/3x3_s2_s_0_split_0',\n",
       " 'pool3/3x3_s2_s_pool3/3x3_s2_s_0_split_1',\n",
       " 'pool3/3x3_s2_s_pool3/3x3_s2_s_0_split_2',\n",
       " 'pool3/3x3_s2_s_pool3/3x3_s2_s_0_split_3',\n",
       " 'inception_4a/1x1_s',\n",
       " 'inception_4a/3x3_reduce_s',\n",
       " 'inception_4a/3x3_s',\n",
       " 'inception_4a/5x5_reduce_s',\n",
       " 'inception_4a/5x5_s',\n",
       " 'inception_4a/pool_s',\n",
       " 'inception_4a/pool_proj_s',\n",
       " 'inception_4a/output_s',\n",
       " 'inception_4a/output_s_inception_4a/output_s_0_split_0',\n",
       " 'inception_4a/output_s_inception_4a/output_s_0_split_1',\n",
       " 'inception_4a/output_s_inception_4a/output_s_0_split_2',\n",
       " 'inception_4a/output_s_inception_4a/output_s_0_split_3',\n",
       " 'inception_4a/output_s_inception_4a/output_s_0_split_4',\n",
       " 'loss1/ave_pool_s',\n",
       " 'loss1/conv_s',\n",
       " 'loss1/fc_s',\n",
       " 'inception_4b/1x1_s',\n",
       " 'inception_4b/3x3_reduce_s',\n",
       " 'inception_4b/3x3_s',\n",
       " 'inception_4b/5x5_reduce_s',\n",
       " 'inception_4b/5x5_s',\n",
       " 'inception_4b/pool_s',\n",
       " 'inception_4b/pool_proj_s',\n",
       " 'inception_4b/output_s',\n",
       " 'inception_4b/output_s_inception_4b/output_s_0_split_0',\n",
       " 'inception_4b/output_s_inception_4b/output_s_0_split_1',\n",
       " 'inception_4b/output_s_inception_4b/output_s_0_split_2',\n",
       " 'inception_4b/output_s_inception_4b/output_s_0_split_3',\n",
       " 'inception_4c/1x1_s',\n",
       " 'inception_4c/3x3_reduce_s',\n",
       " 'inception_4c/3x3_s',\n",
       " 'inception_4c/5x5_reduce_s',\n",
       " 'inception_4c/5x5_s',\n",
       " 'inception_4c/pool_s',\n",
       " 'inception_4c/pool_proj_s',\n",
       " 'inception_4c/output_s',\n",
       " 'inception_4c/output_s_inception_4c/output_s_0_split_0',\n",
       " 'inception_4c/output_s_inception_4c/output_s_0_split_1',\n",
       " 'inception_4c/output_s_inception_4c/output_s_0_split_2',\n",
       " 'inception_4c/output_s_inception_4c/output_s_0_split_3',\n",
       " 'inception_4d/1x1_s',\n",
       " 'inception_4d/3x3_reduce_s',\n",
       " 'inception_4d/3x3_s',\n",
       " 'inception_4d/5x5_reduce_s',\n",
       " 'inception_4d/5x5_s',\n",
       " 'inception_4d/pool_s',\n",
       " 'inception_4d/pool_proj_s',\n",
       " 'inception_4d/output_s',\n",
       " 'inception_4d/output_s_inception_4d/output_s_0_split_0',\n",
       " 'inception_4d/output_s_inception_4d/output_s_0_split_1',\n",
       " 'inception_4d/output_s_inception_4d/output_s_0_split_2',\n",
       " 'inception_4d/output_s_inception_4d/output_s_0_split_3',\n",
       " 'inception_4d/output_s_inception_4d/output_s_0_split_4',\n",
       " 'loss2/ave_pool_s',\n",
       " 'loss2/conv_s',\n",
       " 'loss2/fc_s',\n",
       " 'inception_4e/1x1_s',\n",
       " 'inception_4e/3x3_reduce_s',\n",
       " 'inception_4e/3x3_s',\n",
       " 'inception_4e/5x5_reduce_s',\n",
       " 'inception_4e/5x5_s',\n",
       " 'inception_4e/pool_s',\n",
       " 'inception_4e/pool_proj_s',\n",
       " 'inception_4e/output_s',\n",
       " 'pool4/3x3_s2_s',\n",
       " 'pool4/3x3_s2_s_pool4/3x3_s2_s_0_split_0',\n",
       " 'pool4/3x3_s2_s_pool4/3x3_s2_s_0_split_1',\n",
       " 'pool4/3x3_s2_s_pool4/3x3_s2_s_0_split_2',\n",
       " 'pool4/3x3_s2_s_pool4/3x3_s2_s_0_split_3',\n",
       " 'inception_5a/1x1_s',\n",
       " 'inception_5a/3x3_reduce_s',\n",
       " 'inception_5a/3x3_s',\n",
       " 'inception_5a/5x5_reduce_s',\n",
       " 'inception_5a/5x5_s',\n",
       " 'inception_5a/pool_s',\n",
       " 'inception_5a/pool_proj_s',\n",
       " 'inception_5a/output_s',\n",
       " 'inception_5a/output_s_inception_5a/output_s_0_split_0',\n",
       " 'inception_5a/output_s_inception_5a/output_s_0_split_1',\n",
       " 'inception_5a/output_s_inception_5a/output_s_0_split_2',\n",
       " 'inception_5a/output_s_inception_5a/output_s_0_split_3',\n",
       " 'inception_5b/1x1_s',\n",
       " 'inception_5b/3x3_reduce_s',\n",
       " 'inception_5b/3x3_s',\n",
       " 'inception_5b/5x5_reduce_s',\n",
       " 'inception_5b/5x5_s',\n",
       " 'inception_5b/pool_s',\n",
       " 'inception_5b/pool_proj_s',\n",
       " 'inception_5b/output_s',\n",
       " 'pool5/7x7_s1_s']"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run Caffe on the GPU; comment this line out and uncomment the next one to run on the CPU.\n",
    "caffe.set_mode_gpu()\n",
    "#caffe.set_mode_cpu()\n",
    "# Both networks are built in TEST phase from the same weight file, each with its own deploy prototxt.\n",
    "sketch_net = caffe.Net(sketch_model, PRETRAINED_FILE, caffe.TEST)\n",
    "img_net = caffe.Net(image_model, PRETRAINED_FILE, caffe.TEST)\n",
    "# Last expression of the cell: display the blob names of the sketch network.\n",
    "sketch_net.blobs.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# TODO: set the output layer names. You can use sketch_net.blobs.keys() to list all the blobs of a network.\n",
    "# Key looked up in the dict returned by sketch_net.forward() to get the sketch feature.\n",
    "output_layer_sketch = 'pool5/7x7_s1_s'\n",
    "# Key looked up in the dict returned by img_net.forward() to get the photo feature.\n",
    "output_layer_image = 'pool5/7x7_s1_p'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Input preprocessor, sized from the 'data' blob of the sketch network.\n",
    "# The same transformer is applied to both photos and sketches further down.\n",
    "data_blob_shape = sketch_net.blobs['data'].data.shape\n",
    "transformer = caffe.io.Transformer({'data': data_blob_shape})\n",
    "# Per-channel mean registered for the 'data' input.\n",
    "transformer.set_mean('data', np.array([104, 117, 123]))\n",
    "# Move the last axis to the front (presumably H x W x C -> C x H x W -- confirm against the image loader).\n",
    "transformer.set_transpose('data',(2,0,1))\n",
    "# Reverse the channel order (presumably RGB -> BGR, the order the mean is given in -- confirm).\n",
    "transformer.set_channel_swap('data', (2,1,0))\n",
    "# Scale factor applied to the raw pixel values.\n",
    "transformer.set_raw_scale('data', 255.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sketchy test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Root folders of the test photos and sketches (each must end with a path separator,\n",
    "# because relative paths are appended by plain string concatenation).\n",
    "# TODO: point these at your local copy of the dataset.\n",
    "# Raw strings keep the Windows backslashes literal; without the r prefix,\n",
    "# '\\U' is parsed as a unicode escape by Python 3 and the cell fails to compile.\n",
    "photo_paths = r'C:\\Users\\Patsorn\\Documents/notebook_backup/SBIR/photos/'\n",
    "sketch_paths = r'C:\\Users\\Patsorn\\Documents/notebook_backup/SBIR/sketches/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Read the list of test photos: one entry per line, trailing whitespace stripped.\n",
    "with open('../list/test_img_list.txt','r') as list_file:\n",
    "    test_img_list = [line.rstrip() for line in list_file]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1250/1250 Extracting sailboat/n04128499_654.jpg... done\n"
     ]
    }
   ],
   "source": [
    "# Extract one feature vector per test photo with the image network.\n",
    "feats = []\n",
    "N = len(test_img_list)\n",
    "# Shape of img_net's 'data' input blob; the reshape below only works for a batch of one image.\n",
    "img_input_shape = img_net.blobs['data'].data.shape\n",
    "for i,path in enumerate(test_img_list):\n",
    "    print '\\r',str(i+1)+'/'+str(N)+ ' '+'Extracting ' +path+'...',\n",
    "    full_path = photo_paths + path\n",
    "    img = transformer.preprocess('data', caffe.io.load_image(full_path.rstrip()))\n",
    "    # Reshape to img_net's own input blob shape, since this batch is fed to img_net.\n",
    "    img_in = np.reshape([img], img_input_shape)\n",
    "    out_img = img_net.forward(data=img_in)\n",
    "    # Copy the output before the next forward pass and flatten it to a 1-D feature vector.\n",
    "    feats.append(np.copy(out_img[output_layer_image]).ravel())\n",
    "    print 'done',\n",
    "# Stack into a 2-D (number of photos, feature dimension) array.\n",
    "# Flattening each output keeps every element, whereas np.resize would silently\n",
    "# truncate or repeat data if the output blob were not shaped (1, C, 1, 1).\n",
    "feats = np.asarray(feats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Build the nearest-neighbour index over all photo features.\n",
    "# Only NearestNeighbors is imported: LSHForest was never used in this notebook and\n",
    "# has been removed from recent scikit-learn releases, where importing it fails.\n",
    "from sklearn.neighbors import NearestNeighbors\n",
    "# n_neighbors equals the number of photos, so every query returns a ranking of the whole set.\n",
    "# Brute-force search with the cosine metric.\n",
    "nbrs  = NearestNeighbors(n_neighbors=np.size(feats,0), algorithm='brute',metric='cosine').fit(feats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ranking: sailboat n04128499_654-5-5.png found at 0 \n",
      "Recall @K=1 =  0.371039290241\n"
     ]
    }
   ],
   "source": [
    "# Compute recall: for every test sketch, rank the photos with the nearest-neighbour\n",
    "# index and record the position at which the photo it belongs to is returned.\n",
    "\n",
    "# Change this (e.g. to '-1.png') to the sketch variation you want to evaluate.\n",
    "SKETCH_SUFFIX = '-5.png'\n",
    "# Sizes are derived from the data instead of being hard-coded.\n",
    "num_gallery = len(test_img_list)\n",
    "sketch_input_shape = sketch_net.blobs['data'].data.shape\n",
    "\n",
    "num_query = 0\n",
    "# count_recall[k] = number of queries whose photo was returned at rank k (0-based).\n",
    "count_recall = [0]*num_gallery\n",
    "for i,img in enumerate(test_img_list):\n",
    "    imgname = img.split('/')[-1]\n",
    "    imgname = imgname.split('.jpg')[0]\n",
    "    imgcat = img.split('/')[0]\n",
    "\n",
    "    # Sketches of this photo live in the category folder and are named '<photo name>-...<suffix>'.\n",
    "    sketch_list = os.listdir(sketch_paths+imgcat)\n",
    "    sketch_img_list = [skg for skg in sketch_list if skg.startswith(imgname+'-') and skg.endswith(SKETCH_SUFFIX)]\n",
    "    for sketch in sketch_img_list:\n",
    "        sketch_path = sketch_paths + imgcat+'/' + sketch\n",
    "        sketch_in = transformer.preprocess('data', plt.imread(sketch_path))\n",
    "        sketch_in = np.reshape([sketch_in], sketch_input_shape)\n",
    "        query = sketch_net.forward(data=sketch_in)\n",
    "        query = np.copy(query[output_layer_sketch])\n",
    "        # kneighbors expects a 2-D (n_queries, n_features) array, so pass the feature as a single row.\n",
    "        distances, indices = nbrs.kneighbors(query.reshape(1, -1))\n",
    "        num_query = num_query+1\n",
    "        print '\\r','...'+sketch+'...',\n",
    "\n",
    "        # Position of photo i (the one this sketch belongs to) in the returned ranking.\n",
    "        hits = np.flatnonzero(indices[0] == i)\n",
    "        if hits.size:\n",
    "            j = int(hits[0])\n",
    "            count_recall[j] = count_recall[j]+1\n",
    "            print '\\r','ranking: '+imgcat+ ' '+sketch  + ' found at '  +str(j),\n",
    "\n",
    "# Fail with a clear message instead of a ZeroDivisionError when nothing was evaluated.\n",
    "if num_query == 0:\n",
    "    raise ValueError('no sketches ending with ' + SKETCH_SUFFIX + ' were found under ' + sketch_paths)\n",
    "\n",
    "# cum_count[k] = number of queries whose photo appears within the top k+1 results.\n",
    "cum_count = np.cumsum(count_recall)\n",
    "# Divide by the number of queries so that a query whose photo is missing from the\n",
    "# ranking still counts as a miss.\n",
    "print '\\nRecall @K=1 = ', 1.00*cum_count[0]/num_query"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
